﻿using System.Security.Claims;

namespace Microsoft.EntityFrameworkCore;

/// <summary>
/// Extension methods on <see cref="DbContext"/> for resolving the current user,
/// translating stored field values into display names via reflection, and
/// toggling an optional <c>AllowTracking</c> property on the context.
/// </summary>
public static class DbContextExtensions
{
    /// <summary>
    /// Resolves the <see cref="User"/> for the current execution context.
    /// </summary>
    /// <remarks>
    /// The user is looked up by the <c>name</c> / <see cref="ClaimTypes.Name"/> claim of the
    /// current HTTP user when one is present; otherwise by the <c>operator</c> entry of
    /// <c>App.Metadata</c>.
    /// </remarks>
    /// <param name="ctx">The context whose <see cref="User"/> set is queried.</param>
    /// <returns>The matching user.</returns>
    /// <exception cref="InvalidOperationException">
    /// Neither a name claim nor an <c>operator</c> metadata entry is available, or the
    /// query finds no matching user.
    /// </exception>
    public static User CurrentUser(this DbContext ctx)
    {
        var nameClaim = FindNameClaim();

        if (nameClaim is not null)
        {
            return ctx.Set<User>().First(n => n.Account == nameClaim.Value);
        }

        if (App.Metadata.TryGetValue("operator", out var userId))
        {
            return ctx.Set<User>().First(n => n.UserId == userId);
        }

        throw new InvalidOperationException("当前上下文中不存在用户信息");
    }

    /// <summary>
    /// Asynchronously resolves the <see cref="User"/> for the current execution context.
    /// </summary>
    /// <remarks>
    /// Uses the same resolution order as <see cref="CurrentUser"/>: the name claim of the
    /// current HTTP user first, then the <c>operator</c> entry of <c>App.Metadata</c>.
    /// </remarks>
    /// <param name="ctx">The context whose <see cref="User"/> set is queried.</param>
    /// <param name="cancellationToken">Token forwarded to the database query.</param>
    /// <returns>The matching user.</returns>
    /// <exception cref="InvalidOperationException">
    /// Neither a name claim nor an <c>operator</c> metadata entry is available, or the
    /// query finds no matching user.
    /// </exception>
    public static async Task<User> CurrentUserAsync(this DbContext ctx, CancellationToken cancellationToken = default)
    {
        var nameClaim = FindNameClaim();

        if (nameClaim is not null)
        {
            return await ctx.Set<User>().FirstAsync(n => n.Account == nameClaim.Value, cancellationToken: cancellationToken);
        }

        if (App.Metadata.TryGetValue("operator", out var userId))
        {
            return await ctx.Set<User>().FirstAsync(n => n.UserId == userId, cancellationToken: cancellationToken);
        }

        throw new InvalidOperationException("当前上下文中不存在用户信息");
    }

    /// <summary>
    /// Translates a stored field value into a display name when the field is a generic
    /// collection of entities; otherwise returns the value unchanged.
    /// </summary>
    /// <remarks>
    /// Resolution is best-effort and reflection-based: any failure (unknown type, missing
    /// property, unparsable id, missing related entity, ...) yields <paramref name="value"/>
    /// as-is. The context is expected to expose properties named after the entity types.
    /// </remarks>
    /// <param name="ctx">The context used to look up entities.</param>
    /// <param name="entityTypeName">Simple name of the entity type that owns <paramref name="field"/>.</param>
    /// <param name="entityId">Key of the owning entity.</param>
    /// <param name="field">Name of the property on the owning entity type.</param>
    /// <param name="value">The raw value; parsed as an integer key of the related entity when applicable.</param>
    /// <returns>
    /// The related entity's <c>Name</c> when it can be resolved; otherwise <paramref name="value"/>.
    /// </returns>
    public static string GetFieldCorrespondingValue(this DbContext ctx, string entityTypeName, int entityId, string field, string value)
    {
        if (value is "null")
        {
            return value;
        }

        try
        {
            var entityType = App.Assembly.ExportedTypes.First(n => n.Name == entityTypeName);

            var findMethod = typeof(DbSet<>).MakeGenericType(entityType).GetMethod("Find")!;
            var fieldInfo = entityType.GetProperty(field)!;

            var principalDbSet = ctx.GetType().GetProperty(entityTypeName)!.GetValue(ctx);

            // NOTE(review): the principal entity is looked up but its result is never read.
            // The call is kept so existing behaviour (including its failure paths) is
            // unchanged; confirm whether it can be removed.
            _ = findMethod.Invoke(principalDbSet, new object[] { new object[] { entityId } });

            var fieldType = fieldInfo.PropertyType;

            if (!fieldType.GetInterfaces().Contains(typeof(IEnumerable)))
            {
                return value;
            }

            // Only generic collections whose element type is a class are treated as relations.
            if (fieldType.GenericTypeArguments.Length == 0 || !fieldType.GenericTypeArguments[0].IsClass)
            {
                return value;
            }

            var relatedType = fieldType.GenericTypeArguments[0];
            var relatedDbSet = ctx.GetType().GetProperty(relatedType.Name)!.GetValue(ctx);
            var relatedFind = typeof(DbSet<>).MakeGenericType(relatedType).GetMethod("Find")!;

            var relatedInstance = relatedFind.Invoke(relatedDbSet, new object[] { new object[] { int.Parse(value) } });

            var nameProp = relatedType.GetProperty("Name");

            // Fall back to the raw value when the related entity has no Name property or
            // its Name is null, so this method never returns null.
            return nameProp is null
                ? value
                : (nameProp.GetValue(relatedInstance)?.ToString() ?? value);
        }
        catch
        {
            // Deliberate best-effort: any reflection or lookup failure returns the raw value.
            return value;
        }
    }

    /// <summary>
    /// Sets the context's <c>AllowTracking</c> property, if the runtime type declares one.
    /// </summary>
    /// <typeparam name="TWrapperContext">
    /// Not used by the implementation; retained for source compatibility with callers.
    /// </typeparam>
    /// <param name="context">The context to configure.</param>
    /// <param name="allowTracking">Value assigned to <c>AllowTracking</c>.</param>
    /// <returns>The same <paramref name="context"/> instance, for chaining.</returns>
    public static DbContext SetTracking<TWrapperContext>(this DbContext context, bool allowTracking = true)
        where TWrapperContext : DbContext
    {
        var prop = context.GetType().GetProperty("AllowTracking");

        // No-op when the concrete context type has no such property.
        prop?.SetValue(context, allowTracking);

        return context;
    }

    /// <summary>
    /// Returns the first <c>name</c> / <see cref="ClaimTypes.Name"/> claim of the current
    /// HTTP user, or <c>null</c> when there is no HTTP context or no such claim.
    /// </summary>
    private static Claim? FindNameClaim()
    {
        // Null-conditional access: without an HTTP context the callers must be able to
        // fall through to the metadata-based lookup instead of throwing.
        return App.HttpContext?.User.Claims.FirstOrDefault(n => n.Type is "name" or ClaimTypes.Name);
    }
}